import argparse
from gqe import GQE
from q2b import Q2B
from q2p import Q2P
from line_predictor_model.model import CQDTransEA
from line_predictor_model.regularizers import Regularizer, N2,N3


import torch
from dataloader import TrainDataset, ValidDataset, TestDataset, SingledirectionalOneShotIterator, separate_query_dict
from dataloader import baseline_abstraction, abstraction
from torch.utils.data import DataLoader
import os
from tqdm import tqdm
import numpy as np
from datetime import datetime
from tensorboardX import SummaryWriter
import gc
import pickle
from torch.optim.lr_scheduler import LambdaLR
import json
import networkx as nx
import collections

from numeral_encoder import PositionalEncoder, DICE, GMM_Prototype, DigitRNN
import time

# Run on GPU index 5 when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): the GPU index is hard-coded; confirm the target machine
# actually exposes a "cuda:5" device.
device = torch.device("cuda:5" if torch.cuda.is_available() else "cpu")

# Turn on autograd anomaly detection for the whole process.
# NOTE(review): PyTorch documents this as a debugging aid that slows
# execution; confirm it is still wanted for this script.
torch.autograd.set_detect_anomaly(True)

def log_aggregation(list_of_logs):
    """Average every metric across a list of per-query log dicts.

    Each element of ``list_of_logs`` is a mapping from metric name to value.
    Keys whose name contains ``"num_answers"`` are bookkeeping fields and are
    left out of the result. The mean is taken over queries (one entry per
    log), not over answers.

    Args:
        list_of_logs: iterable of dicts, one per evaluated query.

    Returns:
        A dict mapping each remaining metric name to the ``np.mean`` of the
        values collected for it. An empty input yields an empty dict.

    Raises:
        ValueError: if a log reports ``"inf_num_answers"`` equal to 0.
    """
    collected = {}

    for entry in list_of_logs:
        # A query with zero answers is treated as a hard error rather than
        # being silently dropped from the average.
        if entry.get("inf_num_answers") == 0:
            raise ValueError("no value error!")

        for metric, score in entry.items():
            if "num_answers" in metric:
                continue
            collected.setdefault(metric, []).append(score)

    return {metric: np.mean(scores) for metric, scores in collected.items()}
def log_mu_prediction(list_of_logs):
    """Average the four mu-prediction error metrics over a list of logs.

    Args:
        list_of_logs: iterable of dicts; every dict must contain the keys
            ``squared_error_mu_prediction``, ``squared_error_mu_ave``,
            ``absolute_error_mu_prediction`` and ``absolute_error_mu_ave``.

    Returns:
        A dict mapping each of those keys to the ``np.mean`` of its values.
        An empty input yields an empty dict.

    Raises:
        KeyError: if any log is missing one of the four keys.
    """
    metric_names = (
        "squared_error_mu_prediction",
        "squared_error_mu_ave",
        "absolute_error_mu_prediction",
        "absolute_error_mu_ave",
    )

    collected = collections.defaultdict(list)
    for entry in list_of_logs:
        for name in metric_names:
            collected[name].append(entry[name])

    return {name: np.mean(values) for name, values in collected.items()}

if __name__ == "__main__":
    # Evaluation-only entry point: parse the CLI, load the test queries and
    # the train/test graphs, build the numeral encoders and the query model,
    # restore both checkpoints, run the task-2 evaluation over every test
    # loader and write the averaged metrics to the outcome file.

    parser = argparse.ArgumentParser(description='The training and evaluation script for the models')

    parser.add_argument('--query_data_dir', default="sampled_data_small", help="The path to the sampled queries.")
    parser.add_argument('--kg_data_dir', default="KG_data/", help="The path the original kg data")

    parser.add_argument("--test_query_dir", required=True)

    parser.add_argument('--log_steps', default=50000, type=int, help='train log every xx steps')
    parser.add_argument('-dn', '--data_name', type=str, required=True)
    parser.add_argument('-b', '--batch_size', default=64, type=int)

    parser.add_argument('-d', '--entity_space_dim', default=400, type=int)
    parser.add_argument('-lr', '--learning_rate', default=0.001, type=float)
    parser.add_argument('-wc', '--weight_decay', default=0.0000, type=float)

    parser.add_argument("--optimizer", default="adam", type=str)
    parser.add_argument("--dropout_rate", default=0.1, type=float)
    parser.add_argument('-ls', "--label_smoothing", default=0.1, type=float)
    parser.add_argument('-nls', "--numerical_label_smoothing", default=0.3, type=float)

    parser.add_argument('--max_train_step', default=540000, type=int)

    parser.add_argument("--warm_up_steps", default=1000, type=int)

    parser.add_argument("-m", "--model", required=True)

    parser.add_argument("--experiment_number", type=str, default="e34")

    parser.add_argument("--numeral_encoder", default="dice", type=str)
    parser.add_argument("--mixed_value_reprerentation", action="store_true")

    parser.add_argument("--small", action="store_true")

    parser.add_argument("--timing", action="store_true")

    parser.add_argument("--typed", action="store_true")

    parser.add_argument("--quantile", action="store_true")

    parser.add_argument("-ga", "--gradient_accumulation_steps", type=int, default=1)

    args = parser.parse_args()

    data_name = args.data_name

    # Load the test query dataset.
    # NOTE(review): pickle.load can execute arbitrary code; only point
    # --test_query_dir at files from a trusted source.
    with open(args.test_query_dir, "rb") as query_file:
        test_data_dict = pickle.load(query_file)

    # Load the training graph and the test graph.
    # NOTE(review): nx.read_gpickle was removed in NetworkX 3.0, so this
    # requires networkx < 3 -- confirm the pinned version.
    data_dir = args.data_name
    print("Load Train Graph " + data_dir)
    train_path = "/home/pxs/projects/paper_nrn_new/model_proposed_accelerate_numpre_task2/mu_pre_model_traindata/" + data_dir + "_train_with_units.pkl"
    train_graph = nx.read_gpickle(train_path)

    print("Load Test Graph " + data_dir)
    test_path = "/home/pxs/projects/paper_nrn_new/model_proposed_accelerate_numpre_task2/mu_pre_model_traindata/" + data_dir + "_test_with_units.pkl"
    test_graph = nx.read_gpickle(test_path)

    # Tuple nodes are the numeric value nodes: element 0 is the number and
    # element 1 is used as the type key for the per-type groupings below.
    all_values = [u[0] for u in test_graph.nodes() if isinstance(u, tuple)]
    train_values = [u[0] for u in train_graph.nodes() if isinstance(u, tuple)]

    all_typed_values = {}
    for u in test_graph.nodes():
        if isinstance(u, tuple):
            all_typed_values.setdefault(u[1], []).append(u[0])

    # Per-type value range. Types are looked up by the integers
    # 0..len(all_typed_values)-1, so the type keys must be exactly those.
    # The bounds are stored without binding names that shadow the builtins
    # ``min`` / ``max`` at module scope.
    # Alternative kept from the original: np.percentile(numlist, 3) and
    # np.percentile(numlist, 97) instead of the hard min/max.
    num_min_max = {}
    for i in range(len(all_typed_values)):
        numlist = all_typed_values[i]
        num_min_max[i] = {"min": np.min(numlist), "max": np.max(numlist)}

    # Resolve the numeral encoder class from the CLI choice.
    encoder_types = {
        "positional": PositionalEncoder,
        "dice": DICE,
        "gmm": GMM_Prototype,
    }
    if args.numeral_encoder not in encoder_types:
        raise ValueError("Invalid numeral encoder")
    EncoderType = encoder_types[args.numeral_encoder]

    train_typed_values = {}
    for u in train_graph.nodes():
        if isinstance(u, tuple):
            train_typed_values.setdefault(u[1], []).append(u[0])

    if args.typed:
        # One encoder per value type; --quantile additionally passes
        # scaler="quantile" to each encoder.
        scaler_kwargs = {"scaler": "quantile"} if args.quantile else {}
        encoder_list = [
            EncoderType(device, output_size=300,
                        train_values=train_typed_values[i],
                        all_values=all_typed_values[i],
                        **scaler_kwargs)
            for i in range(len(all_typed_values))
        ]
    else:
        # A single encoder shared by every value.
        encoder_list = [
            EncoderType(device, output_size=300,
                        train_values=train_values, all_values=all_values)
        ]

    # Count entity (str) and value (tuple) nodes of the test graph and give
    # every value node an index in iteration order.
    tuple_all_values = [u for u in test_graph.nodes() if isinstance(u, tuple)]
    value_counter = len(tuple_all_values)
    entity_counter = sum(1 for u in test_graph.nodes() if isinstance(u, str))

    value_vocab = {value: index for index, value in enumerate(tuple_all_values)}

    # Classify the edge-attribute keys by the kinds of their endpoints.
    # Only the number of distinct keys per category is used below.
    relation_edges = set()
    attribute_edges = set()
    reverse_attribute_edges = set()
    numerical_edges = set()

    for u, v, a in test_graph.edges(data=True):
        u_is_value = isinstance(u, tuple)
        v_is_value = isinstance(v, tuple)
        if u_is_value and v_is_value:
            numerical_edges.update(a)
        elif u_is_value:
            reverse_attribute_edges.update(a)
        elif v_is_value:
            attribute_edges.update(a)
        elif isinstance(u, str) and isinstance(v, str):
            relation_edges.update(a)

    relation_edges_list = list(relation_edges)
    attribute_edges_list = list(attribute_edges)
    reverse_attribute_edges_list = list(reverse_attribute_edges)
    numerical_edges_list = list(numerical_edges)

    nentity = entity_counter
    nvalue = value_counter

    nrelation = len(relation_edges_list)
    nattribute = len(attribute_edges_list)
    nnumerical_proj = len(numerical_edges_list)

    batch_size = args.batch_size

    print("====== Create Testing Dataloader ======")
    test_loaders = {}
    for query_type, query_answer_dict in test_data_dict.items():
        sub_query_types_dicts = separate_query_dict(query_answer_dict, nentity, nrelation)

        for sub_query_type, sub_query_types_dict in sub_query_types_dicts.items():
            test_loaders[sub_query_type] = DataLoader(
                TestDataset(nentity, nrelation, sub_query_types_dict,
                            baseline=False, nattribute=nattribute, value_vocab=value_vocab),
                batch_size=batch_size,
                shuffle=True,
                collate_fn=TestDataset.collate_fn
            )

    # All supported models take the same constructor arguments, so pick the
    # class from a table instead of repeating the call three times.
    model_classes = {"q2p": Q2P, "gqe": GQE, "q2b": Q2B}
    if args.model not in model_classes:
        raise NotImplementedError(f"Unknown model: {args.model!r}")

    model = model_classes[args.model](
        num_entities=nentity,
        num_relations=nrelation,
        embedding_size=300,
        num_attributes=nattribute,
        num_numrical_proj=nnumerical_proj,
        value_vocab=value_vocab,
        number_encoder_list=encoder_list,
        # NOTE(review): hard-coded to True; the --mixed_value_reprerentation
        # flag is parsed but not consulted here -- confirm this is intended.
        mixed_value_reprerentation=True,
        label_smoothing=args.label_smoothing,
        numerical_label_smoothing=args.numerical_label_smoothing,
    )

    # ``device`` is already the CPU when CUDA is unavailable, so the move is
    # a no-op in that case.
    model = model.to(device)

    # Not stepped anywhere in this evaluation-only script.
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    # Checkpoints are mapped onto the same device the model lives on.
    map_location = device

    # This script never trains or saves the model; it only restores the
    # checkpoint below and evaluates it.
    train_model = False
    if train_model:
        raise ValueError("no Train!!")

    model.load_state_dict(torch.load("/home/pxs/projects/paper_nrn_new/model_proposed_accelerate_numpre_task2/mu_pre_model_parameters/model_parameters.pth", map_location=map_location))
    # Model testing.
    model.eval()

    # Load the link predictor. The metadata JSON is read but its contents
    # are not used further in this script.
    meta_data_path = "/home/pxs/projects/paper_nrn_new/model_proposed_accelerate_numpre_task2/line_predictor_model/model-metadata-1717256032.json"
    with open(meta_data_path, "r") as meta_file:
        meta_data = json.load(meta_file)
    line_predictor_param_path = "/home/pxs/projects/paper_nrn_new/model_proposed_accelerate_numpre_task2/line_predictor_model/epoch_100_1717256032"

    checkpoint = torch.load(line_predictor_param_path, map_location=map_location)
    line_predictor = CQDTransEA(
        checkpoint["entity_num"],
        checkpoint["rank"],
        checkpoint["att_num"],
        checkpoint["relation_num"],
        checkpoint["p_norm"],
        checkpoint["use_attributes"],
        checkpoint["do_sigmoid"]
    ).to(map_location)
    line_predictor.load_state_dict(checkpoint["model_state_dict"])
    line_predictor.eval()

    # Collect per-query logs, both overall and grouped by the loader's
    # abstract query type.
    all_type_entity_logs = []
    all_type_abstracted_entity_logs = {}

    for task_name, loader in test_loaders.items():
        print(task_name)
        for batched_query, easy_answers, hard_answers, query_attribution in loader:
            query_embedding = model(batched_query)
            batch_logs = model.task2_evaluate_generalization(line_predictor, query_embedding, easy_answers, hard_answers, query_attribution)

            all_type_entity_logs.extend(batch_logs)
            abstract_query_type = loader.dataset.query_type
            all_type_abstracted_entity_logs.setdefault(abstract_query_type, []).extend(batch_logs)

    # Record the results. Aggregation runs before the file is opened so that
    # a failure in log_aggregation does not truncate an existing outcome
    # file, and the ``with`` block guarantees the handle is closed.
    all_type_entity_result = log_aggregation(all_type_entity_logs)
    per_type_results = {
        abstract_type: log_aggregation(type_logs)
        for abstract_type, type_logs in all_type_abstracted_entity_logs.items()
    }

    outcome_path = "/home/pxs/projects/paper_nrn_new/model_proposed_accelerate_numpre_task2/outcome/outcome.txt"
    with open(outcome_path, "w") as outcome_file:
        outcome_file.write("test_all_type:\n")
        for key, value in all_type_entity_result.items():
            outcome_file.write(f"{key}:\t{value:.4f}\n")

        outcome_file.write("\n\ntest_each_type_outcome:\n")
        for key, aggregated_value in per_type_results.items():
            outcome_file.write(f"\n{key}:\n")
            for metric, metric_value in aggregated_value.items():
                outcome_file.write(f"{metric}:\t{metric_value:.4f}\n")



















